{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import numpy.random as rd\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Ask about ***now_max_tree_len*** first!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BinarySearchTree:\n",
    "    \"\"\"Array-backed sum tree for Prioritized Experience Replay (PER).\n",
    "\n",
    "    Leaves hold per-transition priorities and every parent holds the sum of its\n",
    "    two children, so ``ps_tree[0]`` is the total priority.\n",
    "\n",
    "    Contributor: Github GyChou\n",
    "    Reference:\n",
    "    https://github.com/kaixindelele/DRLib/tree/main/algos/pytorch/td3_sp\n",
    "    https://github.com/jaromiru/AI-blog/blob/master/SumTree.py\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, memo_len):\n",
    "        \"\"\"Create an all-zero tree with ``memo_len`` leaves.\n",
    "\n",
    "        :param memo_len: replay buffer length (number of leaves).\n",
    "        \"\"\"\n",
    "        self.memo_len = memo_len  # replay buffer len\n",
    "        # SumTree size is 2 * buffer_len - 1: buffer_len - 1 parent nodes and buffer_len leaves.\n",
    "        self.ps_tree = np.zeros(2 * memo_len - 1)\n",
    "        # Fill frontier: tree index of the next leaf expected to be written (starts at the first leaf).\n",
    "        self.now_max_tree_len = self.memo_len - 1\n",
    "        self.indices = None  # data indices of the most recently sampled batch\n",
    "\n",
    "        self.per_alpha = 0.6\n",
    "        self.per_beta = 0.4\n",
    "        self.per_beta_increment_per_sampling = 0.001\n",
    "        self.depth = int(np.log2(2 * memo_len - 1))  # depth of the last array slot (root is depth 0)\n",
    "\n",
    "    def update(self, data_idx, prob=10):  # 10 is max_prob\n",
    "        \"\"\"Set the priority of one leaf and propagate the change up to the root.\n",
    "\n",
    "        :param data_idx: buffer index of the transition (0-based).\n",
    "        :param prob: new priority stored in the leaf.\n",
    "        \"\"\"\n",
    "        tree_idx = data_idx + self.memo_len - 1\n",
    "        if self.now_max_tree_len == tree_idx:  # the frontier only advances when its own leaf is written\n",
    "            self.now_max_tree_len += 1\n",
    "\n",
    "        delta = prob - self.ps_tree[tree_idx]\n",
    "        self.ps_tree[tree_idx] = prob\n",
    "\n",
    "        while tree_idx != 0:  # propagate the change through tree\n",
    "            tree_idx = (tree_idx - 1) // 2  # faster than the recursive loop\n",
    "            self.ps_tree[tree_idx] += delta\n",
    "\n",
    "    def update_(self, data_ids, prob=10):  # 10 is max_prob\n",
    "        \"\"\"Vectorized counterpart of ``update`` for a batch of buffer indices.\n",
    "\n",
    "        The fill frontier advances across the contiguous run of newly written\n",
    "        leaves that starts at the frontier. This matches ``update`` applied to\n",
    "        the same ids in ascending order, instead of counting every id at or\n",
    "        beyond the frontier.\n",
    "\n",
    "        :param data_ids: integer buffer indices (array, list or scalar).\n",
    "        :param prob: new priority, a scalar or one value per index.\n",
    "        \"\"\"\n",
    "        data_ids = np.atleast_1d(np.asarray(data_ids))\n",
    "        if data_ids.size == 0:  # nothing to write\n",
    "            return\n",
    "        ids = data_ids + self.memo_len - 1\n",
    "\n",
    "        # Sorted, de-duplicated leaves at or beyond the frontier: fresh[k] - k equals the\n",
    "        # frontier exactly for the leading contiguous run and is larger afterwards.\n",
    "        fresh = np.unique(ids[ids >= self.now_max_tree_len])\n",
    "        run_len = (fresh - np.arange(fresh.size) == self.now_max_tree_len).sum()\n",
    "        self.now_max_tree_len += int(run_len)\n",
    "\n",
    "        self.ps_tree[ids] = prob  # here, ids are the indices of the given leaves\n",
    "        p_ids = (ids - 1) // 2\n",
    "\n",
    "        # range() is empty for a non-positive count, so tiny trees cannot loop forever\n",
    "        for _ in range(self.depth - 1):  # propagate the change through tree\n",
    "            l_ids = p_ids * 2 + 1  # indices of the left children\n",
    "            self.ps_tree[p_ids] = self.ps_tree[l_ids] + self.ps_tree[l_ids + 1]\n",
    "            p_ids = (p_ids - 1) // 2\n",
    "\n",
    "        if self.ps_tree.size > 1:\n",
    "            # only depth - 1 upward steps are taken, so the root is refreshed separately\n",
    "            self.ps_tree[0] = self.ps_tree[1] + self.ps_tree[2]\n",
    "\n",
    "    def get_leaf_id(self, v):\n",
    "        \"\"\"Walk down from the root and return the tree index of the leaf selected by ``v``.\n",
    "\n",
    "        Tree structure and array storage:\n",
    "        Tree index:\n",
    "              0       -> storing priority sum\n",
    "            |  |\n",
    "          1     2\n",
    "         | |   | |\n",
    "        3  4  5  6    -> storing priority for transitions\n",
    "        Array type for storing:\n",
    "        [0,1,2,3,4,5,6]\n",
    "\n",
    "        :param v: value in the cumulative-priority range to locate.\n",
    "        :return: leaf tree index, clamped to ``now_max_tree_len - 2``.\n",
    "        \"\"\"\n",
    "        tree_size = len(self.ps_tree)\n",
    "        parent_idx = 0\n",
    "        while True:  # the while loop is faster than the method in the reference code\n",
    "            l_idx = 2 * parent_idx + 1  # left child\n",
    "            if l_idx >= tree_size:  # reach bottom, end search\n",
    "                break\n",
    "            if v <= self.ps_tree[l_idx]:  # target lies inside the left subtree\n",
    "                parent_idx = l_idx\n",
    "            else:  # skip the left subtree's mass and descend right\n",
    "                v -= self.ps_tree[l_idx]\n",
    "                parent_idx = l_idx + 1\n",
    "        # NOTE(review): the - 2 clamp is kept from the original; confirm the intended offset.\n",
    "        return min(parent_idx, self.now_max_tree_len - 2)\n",
    "\n",
    "    def get_indices_is_weights(self, batch_size, beg, end):\n",
    "        \"\"\"Sample ``batch_size`` leaves proportionally to priority.\n",
    "\n",
    "        :param batch_size: number of samples to draw.\n",
    "        :param beg: start of the ``ps_tree`` slice whose minimum normalizes the weights.\n",
    "        :param end: end (exclusive) of that slice.\n",
    "            NOTE(review): presumably the tree-index range of the filled leaves; confirm with the caller.\n",
    "        :return: ``(indices, is_weights)``; the indices are also stored in ``self.indices``.\n",
    "        \"\"\"\n",
    "        self.per_beta = min(1., self.per_beta + self.per_beta_increment_per_sampling)  # max = 1\n",
    "\n",
    "        # one uniform draw inside each of batch_size equal slices of the total priority\n",
    "        values = (rd.rand(batch_size) + np.arange(batch_size)) * (self.ps_tree[0] / batch_size)\n",
    "\n",
    "        # get proportional prioritization\n",
    "        leaf_ids = np.array([self.get_leaf_id(v) for v in values])\n",
    "        self.indices = leaf_ids - (self.memo_len - 1)\n",
    "\n",
    "        probs = self.ps_tree[leaf_ids] / self.ps_tree[beg:end].min()\n",
    "        is_weights = np.power(probs, -self.per_beta)  # importance sampling weights\n",
    "        return self.indices, is_weights\n",
    "\n",
    "    def td_error_update(self, td_error):  # td_error = (q-q).detach_().abs()\n",
    "        \"\"\"Write new priorities for the most recently sampled batch.\n",
    "\n",
    "        :param td_error: object exposing ``clamp``, ``pow``, ``cpu`` and ``numpy``\n",
    "            (presumably a torch tensor aligned with ``self.indices``; confirm with the caller).\n",
    "        \"\"\"\n",
    "        prob = td_error.clamp(1e-6, 10).pow(self.per_alpha)\n",
    "        prob = prob.cpu().numpy()\n",
    "        for data_idx, p in zip(self.indices, prob):\n",
    "            self.update(data_idx, p)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two identical empty trees: one for the per-item update, one for the batched update.\n",
    "my_tree1, my_tree2 = (BinarySearchTree(1000) for _ in range(2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(999, 999)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fill frontier of both trees before any update.\n",
    "tuple(tree.now_max_tree_len for tree in (my_tree1, my_tree2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "rd.seed(0)  # fixed seed so the benchmark ids are reproducible across runs\n",
    "# np.unique sorts and de-duplicates, so every buffer index appears at most once\n",
    "ids = np.unique(rd.randint(0, 1000, size=(1000,)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time for update: 0.028980016708374023\n"
     ]
    }
   ],
   "source": [
    "# perf_counter is a monotonic, high-resolution clock intended for timing short intervals\n",
    "t0 = time.perf_counter()\n",
    "\n",
    "for i in ids:\n",
    "    my_tree1.update(i)\n",
    "\n",
    "t1 = time.perf_counter()\n",
    "normal_time = t1 - t0\n",
    "print(\"time for update: {}\".format(normal_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time for numpy update: 0.001996755599975586\n",
      "speedup ratio: 14.51355223880597\n"
     ]
    }
   ],
   "source": [
    "# perf_counter keeps numpy_time from collapsing to 0 on coarse wall clocks (avoids dividing by zero below)\n",
    "t0 = time.perf_counter()\n",
    "\n",
    "my_tree2.update_(ids)\n",
    "\n",
    "t1 = time.perf_counter()\n",
    "numpy_time = t1 - t0\n",
    "print(\"time for numpy update: {}\".format(numpy_time))\n",
    "print(\"speedup ratio: {}\".format(normal_time / numpy_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sum of squared differences between the two trees; 0.0 means both update paths agree.\n",
    "np.sum(np.square(my_tree1.ps_tree - my_tree2.ps_tree))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO:** ask jh whether the `now_max_tree_len` values below are expected (the two update paths report different values)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1001, 1631)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fill frontier of both trees after the updates.\n",
    "tuple(tree.now_max_tree_len for tree in (my_tree1, my_tree2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
